{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Making our RNN state of the art"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai2.text.all import *\n",
    "path = untar_data(URLs.HUMAN_NUMBERS)\n",
    "lines = L()\n",
    "with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
    "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
    "text = ' . '.join([l.strip() for l in lines])\n",
    "tokens = text.split(' ')\n",
    "vocab = L(*tokens).unique()\n",
    "word2idx = {w:i for i,w in enumerate(vocab)}\n",
    "nums = L(word2idx[i] for i in tokens)\n",
    "\n",
    "def group_chunks(ds, bs):\n",
    "    m = len(ds) // bs\n",
    "    new_ds = L()\n",
    "    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
    "    return new_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sl,bs = 16,64\n",
    "seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1])) for i in range(0,len(nums)-sl-1,sl))\n",
    "cut = int(len(seqs) * 0.8)\n",
    "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multilayer RNNs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel5(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = torch.zeros(n_layers, bs, n_hidden)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(res)\n",
    "    \n",
    "    def reset(self): self.h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.048115</td>\n",
       "      <td>2.622384</td>\n",
       "      <td>0.434001</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.136388</td>\n",
       "      <td>1.763967</td>\n",
       "      <td>0.471191</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.689246</td>\n",
       "      <td>1.898718</td>\n",
       "      <td>0.364746</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.443545</td>\n",
       "      <td>1.747440</td>\n",
       "      <td>0.480387</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.271023</td>\n",
       "      <td>1.870939</td>\n",
       "      <td>0.479980</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.101259</td>\n",
       "      <td>1.794428</td>\n",
       "      <td>0.495361</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.948380</td>\n",
       "      <td>1.769644</td>\n",
       "      <td>0.511149</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.822373</td>\n",
       "      <td>1.800406</td>\n",
       "      <td>0.535400</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.731188</td>\n",
       "      <td>1.914065</td>\n",
       "      <td>0.522461</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.662659</td>\n",
       "      <td>1.987547</td>\n",
       "      <td>0.525798</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.613053</td>\n",
       "      <td>2.022102</td>\n",
       "      <td>0.527751</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.577007</td>\n",
       "      <td>2.068472</td>\n",
       "      <td>0.526530</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.551144</td>\n",
       "      <td>2.113533</td>\n",
       "      <td>0.521566</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.535356</td>\n",
       "      <td>2.123089</td>\n",
       "      <td>0.523600</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.526783</td>\n",
       "      <td>2.122413</td>\n",
       "      <td>0.524333</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel5(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)\n",
    "learn.fit_one_cycle(15, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Handling exploding or disappearing activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building an LSTM from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        self.forget_gate = nn.Linear(ni + nh, nh)\n",
    "        self.input_gate  = nn.Linear(ni + nh, nh)\n",
    "        self.cell_gate   = nn.Linear(ni + nh, nh)\n",
    "        self.output_gate = nn.Linear(ni + nh, nh)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        h,c = state\n",
    "        h = torch.stack([x, input], dim=1)\n",
    "        forget = torch.sigmoid(self.forget_gate(h))\n",
    "        c = c * forget\n",
    "        inp = torch.sigmoid(self.input_gate(h))\n",
    "        cell = torch.tanh(self.cell_gate(h))\n",
    "        c = c + inp * cell\n",
    "        out = torch.sigmoid(self.output_gate(h))\n",
    "        h = outgate * torch.tanh(c)\n",
    "        return h, (h,c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        self.ih = nn.Linear(ni,4*nh)\n",
    "        self.hh = nn.Linear(nh,4*nh)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        h,c = state\n",
    "        #One big multiplication for all the gates is better than 4 smaller ones\n",
    "        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
    "        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
    "        cellgate = gates[3].tanh()\n",
    "\n",
    "        c = (forgetgate*c) + (ingate*cellgate)\n",
    "        h = outgate * c.tanh()\n",
    "        return h, (h,c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a language model using LSTMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel6(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        self.h = [h_.detach() for h_ in h]\n",
    "        return self.h_o(res)\n",
    "    \n",
    "    def reset(self): \n",
    "        for h in self.h: h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.031346</td>\n",
       "      <td>2.749381</td>\n",
       "      <td>0.279215</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.219651</td>\n",
       "      <td>2.084450</td>\n",
       "      <td>0.204753</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.659518</td>\n",
       "      <td>1.685639</td>\n",
       "      <td>0.479574</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.410550</td>\n",
       "      <td>1.666663</td>\n",
       "      <td>0.509440</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.204062</td>\n",
       "      <td>1.606485</td>\n",
       "      <td>0.541829</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.021459</td>\n",
       "      <td>1.529109</td>\n",
       "      <td>0.592448</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.785871</td>\n",
       "      <td>1.340280</td>\n",
       "      <td>0.642008</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.547519</td>\n",
       "      <td>1.271710</td>\n",
       "      <td>0.688802</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.339775</td>\n",
       "      <td>1.216605</td>\n",
       "      <td>0.753825</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.197550</td>\n",
       "      <td>1.218557</td>\n",
       "      <td>0.743652</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.114297</td>\n",
       "      <td>1.253571</td>\n",
       "      <td>0.751139</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.071301</td>\n",
       "      <td>1.314827</td>\n",
       "      <td>0.752686</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.049507</td>\n",
       "      <td>1.307375</td>\n",
       "      <td>0.765462</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.038810</td>\n",
       "      <td>1.287779</td>\n",
       "      <td>0.767741</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.033738</td>\n",
       "      <td>1.292951</td>\n",
       "      <td>0.767985</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(dls, LMModel6(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)\n",
    "learn.fit_one_cycle(15, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularizing an LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dropout(Module):\n",
    "    def __init__(self, p): self.p = p\n",
    "    def forward(self, x):\n",
    "        if self.training: return x\n",
    "        mask = x.new(*x.shape).bernoulli_(1-p)\n",
    "        return x * mask.div_(1-p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AR and TAR regularization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a regularized LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMModel7(Module):\n",
    "    def __init__(self, vocab_sz, n_hidden, n_layers, p):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.drop = nn.Dropout(p)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n",
    "        \n",
    "    def forward(self, x):\n",
    "        raw,h = self.rnn(self.i_h(x), self.h)\n",
    "        out = self.drop(raw)\n",
    "        self.h = [h_.detach() for h_ in h]\n",
    "        return self.h_o(out),raw,out\n",
    "    \n",
    "    def reset(self): \n",
    "        for h in self.h: h.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.4), loss_func=CrossEntropyLossFlat(), \n",
    "                metrics=accuracy, cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>3.145553</td>\n",
       "      <td>2.495994</td>\n",
       "      <td>0.437581</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.333189</td>\n",
       "      <td>1.674463</td>\n",
       "      <td>0.491862</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.678753</td>\n",
       "      <td>1.500536</td>\n",
       "      <td>0.553955</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.111904</td>\n",
       "      <td>1.040109</td>\n",
       "      <td>0.748779</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.707829</td>\n",
       "      <td>0.773369</td>\n",
       "      <td>0.807699</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.465899</td>\n",
       "      <td>0.621159</td>\n",
       "      <td>0.829346</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.335249</td>\n",
       "      <td>0.649926</td>\n",
       "      <td>0.839193</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.254418</td>\n",
       "      <td>0.586989</td>\n",
       "      <td>0.841064</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.205191</td>\n",
       "      <td>0.527288</td>\n",
       "      <td>0.850179</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.172876</td>\n",
       "      <td>0.460011</td>\n",
       "      <td>0.868652</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.151452</td>\n",
       "      <td>0.500604</td>\n",
       "      <td>0.860677</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.136872</td>\n",
       "      <td>0.480342</td>\n",
       "      <td>0.863525</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.127576</td>\n",
       "      <td>0.496534</td>\n",
       "      <td>0.858398</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.122187</td>\n",
       "      <td>0.475931</td>\n",
       "      <td>0.867025</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.119538</td>\n",
       "      <td>0.490366</td>\n",
       "      <td>0.861165</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4), loss_func=CrossEntropyLossFlat(), \n",
    "                metrics=accuracy)\n",
    "learn.fit_one_cycle(15, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
